{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Author: Sebastian Raschka\n",
      "\n",
      "Python implementation: CPython\n",
      "Python version       : 3.8.8\n",
      "IPython version      : 7.21.0\n",
      "\n",
      "torch: 1.8.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Distribute a Model Across Multiple GPUs with Pipeline Parallelism (VGG-16 Example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates the pipeline parallelism support added in PyTorch 1.8, using VGG-16 as an example. For more details, see https://pytorch.org/docs/1.8.0/pipeline.html?highlight=pipeline#."
   ]
  },
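  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough mental model (a simplified sketch, not PyTorch's actual scheduler), pipeline parallelism splits the network into consecutive stages on different GPUs and splits each mini-batch into micro-batches, so that stage `s` can process micro-batch `m` while stage `s + 1` processes micro-batch `m - 1`. The pure-Python snippet below prints such a GPipe-style forward schedule; the function `forward_schedule` is hypothetical and not part of PyTorch:\n",
    "\n",
    "```python\n",
    "def forward_schedule(num_stages, num_microbatches):\n",
    "    # At clock tick t, stage s works on micro-batch t - s (if that micro-batch exists)\n",
    "    num_ticks = num_microbatches + num_stages - 1\n",
    "    schedule = []\n",
    "    for t in range(num_ticks):\n",
    "        row = []\n",
    "        for s in range(num_stages):\n",
    "            m = t - s\n",
    "            row.append(m if 0 <= m < num_microbatches else None)\n",
    "        schedule.append(row)\n",
    "    return schedule\n",
    "\n",
    "\n",
    "stages, chunks = 3, 4\n",
    "for t, row in enumerate(forward_schedule(stages, chunks)):\n",
    "    cells = ['--' if m is None else f'm{m}' for m in row]\n",
    "    print(f'tick {t}: ' + ' '.join(cells))\n",
    "\n",
    "# Share of stage-ticks spent idle (the pipeline bubble): (K - 1) / (M + K - 1)\n",
    "bubble = (stages - 1) / (chunks + stages - 1)\n",
    "print(f'idle fraction: {bubble:.2f}')  # idle fraction: 0.33\n",
    "```\n",
    "\n",
    "With `K` stages and `M` micro-batches, the forward pass takes `M + K - 1` ticks in this simplified model; using more micro-batches shrinks the idle fraction at the cost of smaller per-GPU workloads."
   ]
  },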
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1) Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch\n",
    "\n",
    "sys.path.insert(0, \"..\") # to include ../helper_evaluate.py etc.\n",
    "\n",
    "from helper_utils import set_all_seeds, set_deterministic\n",
    "from helper_evaluate import compute_accuracy\n",
    "from helper_data import get_dataloaders_cifar10\n",
    "from helper_train import train_classifier_simple_v1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Data settings\n",
    "num_classes = 10\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.0001\n",
    "batch_size = 128\n",
    "num_epochs = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_all_seeds(random_seed)\n",
    "#set_deterministic()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### Dataset\n",
    "##########################\n",
    "\n",
    "train_loader, valid_loader, test_loader = get_dataloaders_cifar10(\n",
    "    batch_size, \n",
    "    num_workers=2, \n",
    "    validation_fraction=0.1)"
   ]
  },
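  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the `Batch 0000/0352` counter in the training logs, the snippet below works out the number of mini-batches per epoch. It assumes (the helper is not shown here) that `get_dataloaders_cifar10` holds out `validation_fraction` of the 50,000 CIFAR-10 training images and keeps the final, partial batch:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "num_train_images = 50_000  # size of the CIFAR-10 training split\n",
    "validation_fraction = 0.1\n",
    "batch_size = 128\n",
    "\n",
    "num_valid = int(num_train_images * validation_fraction)\n",
    "num_train = num_train_images - num_valid\n",
    "num_batches = math.ceil(num_train / batch_size)\n",
    "last_batch = num_train - (num_batches - 1) * batch_size\n",
    "\n",
    "print(num_train, num_batches, last_batch)  # 45000 352 72\n",
    "```"
   ]
  },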
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2) Regular (1-GPU) Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section implements the VGG-16 network in the conventional (single-GPU) manner as a reference. The next section replicates it using pipeline parallelism."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### Model\n",
    "##########################\n",
    "\n",
    "\n",
    "class VGG16(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super().__init__()\n",
    "        \n",
    "        # calculate same padding:\n",
    "        # (w - k + 2*p)/s + 1 = o\n",
    "        # => p = (s(o-1) - w + k)/2\n",
    "        \n",
    "        self.block_1 = torch.nn.Sequential(\n",
    "                torch.nn.Conv2d(in_channels=3,\n",
    "                          out_channels=64,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          # (1(32-1)- 32 + 3)/2 = 1\n",
    "                          padding=1), \n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Conv2d(in_channels=64,\n",
    "                          out_channels=64,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "        self.block_2 = torch.nn.Sequential(\n",
    "                torch.nn.Conv2d(in_channels=64,\n",
    "                          out_channels=128,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Conv2d(in_channels=128,\n",
    "                          out_channels=128,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "        self.block_3 = torch.nn.Sequential(        \n",
    "                torch.nn.Conv2d(in_channels=128,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Conv2d(in_channels=256,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),        \n",
    "                torch.nn.Conv2d(in_channels=256,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "          \n",
    "        self.block_4 = torch.nn.Sequential(   \n",
    "                torch.nn.Conv2d(in_channels=256,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),        \n",
    "                torch.nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),        \n",
    "                torch.nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),            \n",
    "                torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "        self.block_5 = torch.nn.Sequential(\n",
    "                torch.nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),            \n",
    "                torch.nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),            \n",
    "                torch.nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                torch.nn.ReLU(),    \n",
    "                torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))             \n",
    "        )\n",
    "            \n",
    "        self.classifier = torch.nn.Sequential(\n",
    "            torch.nn.Flatten(),\n",
    "            torch.nn.Linear(512, 4096),\n",
    "            torch.nn.ReLU(True),\n",
    "            #torch.nn.Dropout(p=0.5),\n",
    "            torch.nn.Linear(4096, 4096),\n",
    "            torch.nn.ReLU(True),\n",
    "            #torch.nn.Dropout(p=0.5),\n",
    "            torch.nn.Linear(4096, num_classes),\n",
    "        )\n",
    "        \n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.block_1(x)\n",
    "        x = self.block_2(x)\n",
    "        x = self.block_3(x)\n",
    "        x = self.block_4(x)\n",
    "        x = self.block_5(x)\n",
    "        x = self.classifier(x) # logits\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "model = VGG16(num_classes=num_classes)\n",
    "\n",
    "device = torch.device('cuda:0')\n",
    "model = model.to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
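  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The padding comment in the model cell, and the `Linear(512, 4096)` input size of the classifier, can be checked with a short calculation. The sketch below (the helper names are illustrative, not part of the notebook's helper modules) applies the output-size formula $o = (w - k + 2p)/s + 1$ through the five blocks for a 32x32 CIFAR-10 input:\n",
    "\n",
    "```python\n",
    "def conv_out(w, k, s, p):\n",
    "    # Output width of a conv or pooling layer: floor((w - k + 2p) / s) + 1\n",
    "    return (w - k + 2 * p) // s + 1\n",
    "\n",
    "\n",
    "def same_padding(w, k, s=1):\n",
    "    # Solve conv_out(w, k, s, p) == w for p: p = (s(w - 1) - w + k) / 2\n",
    "    return (s * (w - 1) - w + k) // 2\n",
    "\n",
    "\n",
    "convs_per_block = [2, 2, 3, 3, 3]  # VGG-16 layout used above\n",
    "w = 32                             # CIFAR-10 images are 32x32\n",
    "sizes = []\n",
    "for n_convs in convs_per_block:\n",
    "    for _ in range(n_convs):\n",
    "        w = conv_out(w, k=3, s=1, p=same_padding(w, k=3))  # 3x3 conv keeps w\n",
    "    w = conv_out(w, k=2, s=2, p=0)                          # 2x2 max pool halves w\n",
    "    sizes.append(w)\n",
    "\n",
    "print(same_padding(32, k=3))  # 1, matching padding=1\n",
    "print(sizes)                  # [16, 8, 4, 2, 1]\n",
    "print(512 * w * w)            # 512 flattened features -> Linear(512, 4096)\n",
    "```"
   ]
  },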
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/050 | Batch 0000/0352 | Loss: 2.3034\n",
      "Epoch: 001/050 | Batch 0200/0352 | Loss: 1.9757\n",
      "***Epoch: 001/050 | Train. Acc.: 24.324% | Loss: 1.883\n",
      "***Epoch: 001/050 | Valid. Acc.: 24.480% | Loss: 1.873\n",
      "Time elapsed: 0.54 min\n",
      "Epoch: 002/050 | Batch 0000/0352 | Loss: 1.8859\n",
      "Epoch: 002/050 | Batch 0200/0352 | Loss: 1.8733\n",
      "***Epoch: 002/050 | Train. Acc.: 33.400% | Loss: 1.765\n",
      "***Epoch: 002/050 | Valid. Acc.: 33.280% | Loss: 1.781\n",
      "Time elapsed: 1.09 min\n",
      "Epoch: 003/050 | Batch 0000/0352 | Loss: 1.8217\n",
      "Epoch: 003/050 | Batch 0200/0352 | Loss: 1.5797\n",
      "***Epoch: 003/050 | Train. Acc.: 41.444% | Loss: 1.540\n",
      "***Epoch: 003/050 | Valid. Acc.: 40.460% | Loss: 1.549\n",
      "Time elapsed: 1.63 min\n",
      "Epoch: 004/050 | Batch 0000/0352 | Loss: 1.6600\n",
      "Epoch: 004/050 | Batch 0200/0352 | Loss: 1.4421\n",
      "***Epoch: 004/050 | Train. Acc.: 48.211% | Loss: 1.400\n",
      "***Epoch: 004/050 | Valid. Acc.: 47.960% | Loss: 1.418\n",
      "Time elapsed: 2.17 min\n",
      "Epoch: 005/050 | Batch 0000/0352 | Loss: 1.4752\n",
      "Epoch: 005/050 | Batch 0200/0352 | Loss: 1.3762\n",
      "***Epoch: 005/050 | Train. Acc.: 55.909% | Loss: 1.202\n",
      "***Epoch: 005/050 | Valid. Acc.: 54.940% | Loss: 1.231\n",
      "Time elapsed: 2.71 min\n",
      "Epoch: 006/050 | Batch 0000/0352 | Loss: 1.2081\n",
      "Epoch: 006/050 | Batch 0200/0352 | Loss: 1.1761\n",
      "***Epoch: 006/050 | Train. Acc.: 62.300% | Loss: 1.034\n",
      "***Epoch: 006/050 | Valid. Acc.: 60.620% | Loss: 1.069\n",
      "Time elapsed: 3.25 min\n",
      "Epoch: 007/050 | Batch 0000/0352 | Loss: 1.0605\n",
      "Epoch: 007/050 | Batch 0200/0352 | Loss: 0.9945\n",
      "***Epoch: 007/050 | Train. Acc.: 67.273% | Loss: 0.909\n",
      "***Epoch: 007/050 | Valid. Acc.: 65.260% | Loss: 0.970\n",
      "Time elapsed: 3.80 min\n",
      "Epoch: 008/050 | Batch 0000/0352 | Loss: 0.9049\n",
      "Epoch: 008/050 | Batch 0200/0352 | Loss: 0.8547\n",
      "***Epoch: 008/050 | Train. Acc.: 69.647% | Loss: 0.844\n",
      "***Epoch: 008/050 | Valid. Acc.: 66.620% | Loss: 0.939\n",
      "Time elapsed: 4.34 min\n",
      "Epoch: 009/050 | Batch 0000/0352 | Loss: 0.8435\n",
      "Epoch: 009/050 | Batch 0200/0352 | Loss: 0.8415\n",
      "***Epoch: 009/050 | Train. Acc.: 71.738% | Loss: 0.787\n",
      "***Epoch: 009/050 | Valid. Acc.: 68.700% | Loss: 0.912\n",
      "Time elapsed: 4.88 min\n",
      "Epoch: 010/050 | Batch 0000/0352 | Loss: 0.7540\n",
      "Epoch: 010/050 | Batch 0200/0352 | Loss: 0.7508\n",
      "***Epoch: 010/050 | Train. Acc.: 73.676% | Loss: 0.734\n",
      "***Epoch: 010/050 | Valid. Acc.: 69.380% | Loss: 0.896\n",
      "Time elapsed: 5.42 min\n",
      "Epoch: 011/050 | Batch 0000/0352 | Loss: 0.6758\n",
      "Epoch: 011/050 | Batch 0200/0352 | Loss: 0.7798\n",
      "***Epoch: 011/050 | Train. Acc.: 74.769% | Loss: 0.708\n",
      "***Epoch: 011/050 | Valid. Acc.: 69.540% | Loss: 0.912\n",
      "Time elapsed: 5.97 min\n",
      "Epoch: 012/050 | Batch 0000/0352 | Loss: 0.5798\n",
      "Epoch: 012/050 | Batch 0200/0352 | Loss: 0.6676\n",
      "***Epoch: 012/050 | Train. Acc.: 77.618% | Loss: 0.649\n",
      "***Epoch: 012/050 | Valid. Acc.: 71.120% | Loss: 0.864\n",
      "Time elapsed: 6.51 min\n",
      "Epoch: 013/050 | Batch 0000/0352 | Loss: 0.5435\n",
      "Epoch: 013/050 | Batch 0200/0352 | Loss: 0.5304\n",
      "***Epoch: 013/050 | Train. Acc.: 72.060% | Loss: 0.811\n",
      "***Epoch: 013/050 | Valid. Acc.: 66.560% | Loss: 1.064\n",
      "Time elapsed: 7.05 min\n",
      "Epoch: 014/050 | Batch 0000/0352 | Loss: 0.8212\n",
      "Epoch: 014/050 | Batch 0200/0352 | Loss: 0.5295\n",
      "***Epoch: 014/050 | Train. Acc.: 80.340% | Loss: 0.565\n",
      "***Epoch: 014/050 | Valid. Acc.: 72.520% | Loss: 0.868\n",
      "Time elapsed: 7.60 min\n",
      "Epoch: 015/050 | Batch 0000/0352 | Loss: 0.4467\n",
      "Epoch: 015/050 | Batch 0200/0352 | Loss: 0.5036\n",
      "***Epoch: 015/050 | Train. Acc.: 80.913% | Loss: 0.556\n",
      "***Epoch: 015/050 | Valid. Acc.: 72.140% | Loss: 0.930\n",
      "Time elapsed: 8.14 min\n",
      "Epoch: 016/050 | Batch 0000/0352 | Loss: 0.4458\n",
      "Epoch: 016/050 | Batch 0200/0352 | Loss: 0.5054\n",
      "***Epoch: 016/050 | Train. Acc.: 83.393% | Loss: 0.497\n",
      "***Epoch: 016/050 | Valid. Acc.: 73.360% | Loss: 0.933\n",
      "Time elapsed: 8.69 min\n",
      "Epoch: 017/050 | Batch 0000/0352 | Loss: 0.3353\n",
      "Epoch: 017/050 | Batch 0200/0352 | Loss: 0.5200\n",
      "***Epoch: 017/050 | Train. Acc.: 83.424% | Loss: 0.505\n",
      "***Epoch: 017/050 | Valid. Acc.: 72.440% | Loss: 0.975\n",
      "Time elapsed: 9.23 min\n",
      "Epoch: 018/050 | Batch 0000/0352 | Loss: 0.4049\n",
      "Epoch: 018/050 | Batch 0200/0352 | Loss: 0.4376\n",
      "***Epoch: 018/050 | Train. Acc.: 85.198% | Loss: 0.447\n",
      "***Epoch: 018/050 | Valid. Acc.: 73.160% | Loss: 0.965\n",
      "Time elapsed: 9.77 min\n",
      "Epoch: 019/050 | Batch 0000/0352 | Loss: 0.3085\n",
      "Epoch: 019/050 | Batch 0200/0352 | Loss: 0.3960\n",
      "***Epoch: 019/050 | Train. Acc.: 86.209% | Loss: 0.419\n",
      "***Epoch: 019/050 | Valid. Acc.: 73.140% | Loss: 1.017\n",
      "Time elapsed: 10.32 min\n",
      "Epoch: 020/050 | Batch 0000/0352 | Loss: 0.2945\n",
      "Epoch: 020/050 | Batch 0200/0352 | Loss: 0.3203\n",
      "***Epoch: 020/050 | Train. Acc.: 86.078% | Loss: 0.411\n",
      "***Epoch: 020/050 | Valid. Acc.: 72.480% | Loss: 1.011\n",
      "Time elapsed: 10.86 min\n",
      "Epoch: 021/050 | Batch 0000/0352 | Loss: 0.2950\n",
      "Epoch: 021/050 | Batch 0200/0352 | Loss: 0.3598\n",
      "***Epoch: 021/050 | Train. Acc.: 88.662% | Loss: 0.340\n",
      "***Epoch: 021/050 | Valid. Acc.: 73.460% | Loss: 1.048\n",
      "Time elapsed: 11.41 min\n",
      "Epoch: 022/050 | Batch 0000/0352 | Loss: 0.2688\n",
      "Epoch: 022/050 | Batch 0200/0352 | Loss: 0.3375\n",
      "***Epoch: 022/050 | Train. Acc.: 90.007% | Loss: 0.310\n",
      "***Epoch: 022/050 | Valid. Acc.: 74.300% | Loss: 1.081\n",
      "Time elapsed: 11.95 min\n",
      "Epoch: 023/050 | Batch 0000/0352 | Loss: 0.1888\n",
      "Epoch: 023/050 | Batch 0200/0352 | Loss: 0.2363\n",
      "***Epoch: 023/050 | Train. Acc.: 91.322% | Loss: 0.261\n",
      "***Epoch: 023/050 | Valid. Acc.: 75.540% | Loss: 1.055\n",
      "Time elapsed: 12.49 min\n",
      "Epoch: 024/050 | Batch 0000/0352 | Loss: 0.1915\n",
      "Epoch: 024/050 | Batch 0200/0352 | Loss: 0.1584\n",
      "***Epoch: 024/050 | Train. Acc.: 91.958% | Loss: 0.241\n",
      "***Epoch: 024/050 | Valid. Acc.: 76.000% | Loss: 1.057\n",
      "Time elapsed: 13.04 min\n",
      "Epoch: 025/050 | Batch 0000/0352 | Loss: 0.1950\n",
      "Epoch: 025/050 | Batch 0200/0352 | Loss: 0.1886\n",
      "***Epoch: 025/050 | Train. Acc.: 89.602% | Loss: 0.304\n",
      "***Epoch: 025/050 | Valid. Acc.: 74.600% | Loss: 1.166\n",
      "Time elapsed: 13.58 min\n",
      "Epoch: 026/050 | Batch 0000/0352 | Loss: 0.1521\n",
      "Epoch: 026/050 | Batch 0200/0352 | Loss: 0.3536\n",
      "***Epoch: 026/050 | Train. Acc.: 92.018% | Loss: 0.241\n",
      "***Epoch: 026/050 | Valid. Acc.: 74.700% | Loss: 1.186\n",
      "Time elapsed: 14.12 min\n",
      "Epoch: 027/050 | Batch 0000/0352 | Loss: 0.0479\n",
      "Epoch: 027/050 | Batch 0200/0352 | Loss: 0.1856\n",
      "***Epoch: 027/050 | Train. Acc.: 93.753% | Loss: 0.196\n",
      "***Epoch: 027/050 | Valid. Acc.: 75.320% | Loss: 1.220\n",
      "Time elapsed: 14.66 min\n",
      "Epoch: 028/050 | Batch 0000/0352 | Loss: 0.0642\n",
      "Epoch: 028/050 | Batch 0200/0352 | Loss: 0.0706\n",
      "***Epoch: 028/050 | Train. Acc.: 94.836% | Loss: 0.152\n",
      "***Epoch: 028/050 | Valid. Acc.: 75.320% | Loss: 1.163\n",
      "Time elapsed: 15.21 min\n",
      "Epoch: 029/050 | Batch 0000/0352 | Loss: 0.0688\n",
      "Epoch: 029/050 | Batch 0200/0352 | Loss: 0.2222\n",
      "***Epoch: 029/050 | Train. Acc.: 94.131% | Loss: 0.181\n",
      "***Epoch: 029/050 | Valid. Acc.: 75.580% | Loss: 1.224\n",
      "Time elapsed: 15.75 min\n",
      "Epoch: 030/050 | Batch 0000/0352 | Loss: 0.1569\n",
      "Epoch: 030/050 | Batch 0200/0352 | Loss: 0.0786\n",
      "***Epoch: 030/050 | Train. Acc.: 96.044% | Loss: 0.126\n",
      "***Epoch: 030/050 | Valid. Acc.: 76.240% | Loss: 1.272\n",
      "Time elapsed: 16.30 min\n",
      "Epoch: 031/050 | Batch 0000/0352 | Loss: 0.0563\n",
      "Epoch: 031/050 | Batch 0200/0352 | Loss: 0.0863\n",
      "***Epoch: 031/050 | Train. Acc.: 96.498% | Loss: 0.112\n",
      "***Epoch: 031/050 | Valid. Acc.: 76.520% | Loss: 1.239\n",
      "Time elapsed: 16.84 min\n",
      "Epoch: 032/050 | Batch 0000/0352 | Loss: 0.0244\n",
      "Epoch: 032/050 | Batch 0200/0352 | Loss: 0.1219\n",
      "***Epoch: 032/050 | Train. Acc.: 96.169% | Loss: 0.121\n",
      "***Epoch: 032/050 | Valid. Acc.: 76.720% | Loss: 1.216\n",
      "Time elapsed: 17.38 min\n",
      "Epoch: 033/050 | Batch 0000/0352 | Loss: 0.1291\n",
      "Epoch: 033/050 | Batch 0200/0352 | Loss: 0.0828\n",
      "***Epoch: 033/050 | Train. Acc.: 95.909% | Loss: 0.136\n",
      "***Epoch: 033/050 | Valid. Acc.: 76.000% | Loss: 1.242\n",
      "Time elapsed: 17.93 min\n",
      "Epoch: 034/050 | Batch 0000/0352 | Loss: 0.0858\n",
      "Epoch: 034/050 | Batch 0200/0352 | Loss: 0.1213\n",
      "***Epoch: 034/050 | Train. Acc.: 96.098% | Loss: 0.124\n",
      "***Epoch: 034/050 | Valid. Acc.: 75.720% | Loss: 1.318\n",
      "Time elapsed: 18.47 min\n",
      "Epoch: 035/050 | Batch 0000/0352 | Loss: 0.0288\n",
      "Epoch: 035/050 | Batch 0200/0352 | Loss: 0.0558\n",
      "***Epoch: 035/050 | Train. Acc.: 95.222% | Loss: 0.162\n",
      "***Epoch: 035/050 | Valid. Acc.: 75.160% | Loss: 1.388\n",
      "Time elapsed: 19.01 min\n",
      "Epoch: 036/050 | Batch 0000/0352 | Loss: 0.0136\n",
      "Epoch: 036/050 | Batch 0200/0352 | Loss: 0.0345\n",
      "***Epoch: 036/050 | Train. Acc.: 95.422% | Loss: 0.151\n",
      "***Epoch: 036/050 | Valid. Acc.: 74.820% | Loss: 1.354\n",
      "Time elapsed: 19.55 min\n",
      "Epoch: 037/050 | Batch 0000/0352 | Loss: 0.0644\n",
      "Epoch: 037/050 | Batch 0200/0352 | Loss: 0.0265\n",
      "***Epoch: 037/050 | Train. Acc.: 94.347% | Loss: 0.187\n",
      "***Epoch: 037/050 | Valid. Acc.: 74.560% | Loss: 1.423\n",
      "Time elapsed: 20.10 min\n",
      "Epoch: 038/050 | Batch 0000/0352 | Loss: 0.0368\n",
      "Epoch: 038/050 | Batch 0200/0352 | Loss: 0.0925\n",
      "***Epoch: 038/050 | Train. Acc.: 95.807% | Loss: 0.135\n",
      "***Epoch: 038/050 | Valid. Acc.: 75.380% | Loss: 1.365\n",
      "Time elapsed: 20.64 min\n",
      "Epoch: 039/050 | Batch 0000/0352 | Loss: 0.0510\n",
      "Epoch: 039/050 | Batch 0200/0352 | Loss: 0.2865\n",
      "***Epoch: 039/050 | Train. Acc.: 96.160% | Loss: 0.125\n",
      "***Epoch: 039/050 | Valid. Acc.: 75.940% | Loss: 1.406\n",
      "Time elapsed: 21.18 min\n",
      "Epoch: 040/050 | Batch 0000/0352 | Loss: 0.1398\n",
      "Epoch: 040/050 | Batch 0200/0352 | Loss: 0.0894\n",
      "***Epoch: 040/050 | Train. Acc.: 98.256% | Loss: 0.056\n",
      "***Epoch: 040/050 | Valid. Acc.: 77.300% | Loss: 1.251\n",
      "Time elapsed: 21.73 min\n",
      "Epoch: 041/050 | Batch 0000/0352 | Loss: 0.0201\n",
      "Epoch: 041/050 | Batch 0200/0352 | Loss: 0.0795\n",
      "***Epoch: 041/050 | Train. Acc.: 97.827% | Loss: 0.071\n",
      "***Epoch: 041/050 | Valid. Acc.: 77.000% | Loss: 1.379\n",
      "Time elapsed: 22.27 min\n",
      "Epoch: 042/050 | Batch 0000/0352 | Loss: 0.0309\n",
      "Epoch: 042/050 | Batch 0200/0352 | Loss: 0.0177\n",
      "***Epoch: 042/050 | Train. Acc.: 98.500% | Loss: 0.049\n",
      "***Epoch: 042/050 | Valid. Acc.: 77.280% | Loss: 1.315\n",
      "Time elapsed: 22.81 min\n",
      "Epoch: 043/050 | Batch 0000/0352 | Loss: 0.0442\n",
      "Epoch: 043/050 | Batch 0200/0352 | Loss: 0.0547\n",
      "***Epoch: 043/050 | Train. Acc.: 98.369% | Loss: 0.051\n",
      "***Epoch: 043/050 | Valid. Acc.: 77.680% | Loss: 1.303\n",
      "Time elapsed: 23.36 min\n",
      "Epoch: 044/050 | Batch 0000/0352 | Loss: 0.0294\n",
      "Epoch: 044/050 | Batch 0200/0352 | Loss: 0.0188\n",
      "***Epoch: 044/050 | Train. Acc.: 97.716% | Loss: 0.073\n",
      "***Epoch: 044/050 | Valid. Acc.: 77.100% | Loss: 1.381\n",
      "Time elapsed: 23.90 min\n",
      "Epoch: 045/050 | Batch 0000/0352 | Loss: 0.0226\n",
      "Epoch: 045/050 | Batch 0200/0352 | Loss: 0.0128\n",
      "***Epoch: 045/050 | Train. Acc.: 98.140% | Loss: 0.059\n",
      "***Epoch: 045/050 | Valid. Acc.: 76.420% | Loss: 1.395\n",
      "Time elapsed: 24.45 min\n",
      "Epoch: 046/050 | Batch 0000/0352 | Loss: 0.0170\n",
      "Epoch: 046/050 | Batch 0200/0352 | Loss: 0.0198\n",
      "***Epoch: 046/050 | Train. Acc.: 98.527% | Loss: 0.047\n",
      "***Epoch: 046/050 | Valid. Acc.: 77.000% | Loss: 1.548\n",
      "Time elapsed: 24.99 min\n",
      "Epoch: 047/050 | Batch 0000/0352 | Loss: 0.0011\n",
      "Epoch: 047/050 | Batch 0200/0352 | Loss: 0.0362\n",
      "***Epoch: 047/050 | Train. Acc.: 98.682% | Loss: 0.041\n",
      "***Epoch: 047/050 | Valid. Acc.: 77.260% | Loss: 1.345\n",
      "Time elapsed: 25.54 min\n",
      "Epoch: 048/050 | Batch 0000/0352 | Loss: 0.0180\n",
      "Epoch: 048/050 | Batch 0200/0352 | Loss: 0.0191\n",
      "***Epoch: 048/050 | Train. Acc.: 98.838% | Loss: 0.036\n",
      "***Epoch: 048/050 | Valid. Acc.: 77.300% | Loss: 1.412\n",
      "Time elapsed: 26.08 min\n",
      "Epoch: 049/050 | Batch 0000/0352 | Loss: 0.0114\n",
      "Epoch: 049/050 | Batch 0200/0352 | Loss: 0.0217\n",
      "***Epoch: 049/050 | Train. Acc.: 98.378% | Loss: 0.054\n",
      "***Epoch: 049/050 | Valid. Acc.: 77.460% | Loss: 1.459\n",
      "Time elapsed: 26.63 min\n",
      "Epoch: 050/050 | Batch 0000/0352 | Loss: 0.0044\n",
      "Epoch: 050/050 | Batch 0200/0352 | Loss: 0.0069\n",
      "***Epoch: 050/050 | Train. Acc.: 98.804% | Loss: 0.036\n",
      "***Epoch: 050/050 | Valid. Acc.: 77.700% | Loss: 1.365\n",
      "Time elapsed: 27.17 min\n",
      "Total Training Time: 27.17 min\n"
     ]
    }
   ],
   "source": [
    "_ = train_classifier_simple_v1(num_epochs=num_epochs, model=model, \n",
    "                               optimizer=optimizer, device=device, \n",
    "                               train_loader=train_loader, valid_loader=valid_loader, \n",
    "                               logging_interval=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3) VGG16 with Pipeline Parallelism"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we define the individual VGG-16 blocks (identical to those in the `VGG16` class above) that will be wrapped into the pipelined model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "block_1 = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(in_channels=3,\n",
    "                  out_channels=64,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  # (1(32-1)- 32 + 3)/2 = 1\n",
    "                  padding=1), \n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Conv2d(in_channels=64,\n",
    "                  out_channels=64,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                     stride=(2, 2))\n",
    ")\n",
    "\n",
    "block_2 = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(in_channels=64,\n",
    "                  out_channels=128,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Conv2d(in_channels=128,\n",
    "                  out_channels=128,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                     stride=(2, 2))\n",
    ")\n",
    "        \n",
    "block_3 = torch.nn.Sequential(        \n",
    "        torch.nn.Conv2d(in_channels=128,\n",
    "                  out_channels=256,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Conv2d(in_channels=256,\n",
    "                  out_channels=256,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),        \n",
    "        torch.nn.Conv2d(in_channels=256,\n",
    "                  out_channels=256,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                     stride=(2, 2))\n",
    ")\n",
    "        \n",
    "          \n",
    "block_4 = torch.nn.Sequential(   \n",
    "        torch.nn.Conv2d(in_channels=256,\n",
    "                  out_channels=512,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),        \n",
    "        torch.nn.Conv2d(in_channels=512,\n",
    "                  out_channels=512,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),        \n",
    "        torch.nn.Conv2d(in_channels=512,\n",
    "                  out_channels=512,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),            \n",
    "        torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                     stride=(2, 2))\n",
    ")\n",
    "        \n",
    "block_5 = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(in_channels=512,\n",
    "                  out_channels=512,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),            \n",
    "        torch.nn.Conv2d(in_channels=512,\n",
    "                  out_channels=512,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),            \n",
    "        torch.nn.Conv2d(in_channels=512,\n",
    "                  out_channels=512,\n",
    "                  kernel_size=(3, 3),\n",
    "                  stride=(1, 1),\n",
    "                  padding=1),\n",
    "        torch.nn.ReLU(),    \n",
    "        torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                     stride=(2, 2))             \n",
    ")\n",
    "            \n",
    "classifier = torch.nn.Sequential(\n",
    "    torch.nn.Flatten(),\n",
    "    torch.nn.Linear(512, 4096),\n",
    "    torch.nn.ReLU(True),\n",
    "    #torch.nn.Dropout(p=0.5),\n",
    "    torch.nn.Linear(4096, 4096),\n",
    "    torch.nn.ReLU(True),\n",
    "    #torch.nn.Dropout(p=0.5),\n",
    "    torch.nn.Linear(4096, num_classes),\n",
    ")\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before setting up the environment for the distributed run, we check whether the `torch.distributed` package is available in our PyTorch build. The following should return `True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.distributed.is_available()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, set the following environment variables for your machine:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For `MASTER_ADDR`, use the IP address of your machine, e.g., 123.45.67.89:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%env MASTER_ADDR=xxx.xx.xx.xx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose a free port:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: MASTER_PORT=8891\n"
     ]
    }
   ],
   "source": [
    "%env MASTER_PORT=8891"
   ]
  },
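  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are unsure which port is free, one option is to let the operating system pick one by binding a socket to port 0, and to set `MASTER_PORT` from Python instead of with `%env`. This is a sketch: `find_free_port` is a hypothetical helper, and another process could in principle take the port before the RPC setup binds it:\n",
    "\n",
    "```python\n",
    "import os\n",
    "import socket\n",
    "\n",
    "\n",
    "def find_free_port():\n",
    "    # Binding to port 0 makes the OS assign an unused port\n",
    "    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n",
    "        s.bind(('', 0))\n",
    "        return s.getsockname()[1]\n",
    "\n",
    "\n",
    "port = find_free_port()\n",
    "os.environ['MASTER_PORT'] = str(port)  # same effect as %env MASTER_PORT=<port>\n",
    "print(os.environ['MASTER_PORT'])\n",
    "```"
   ]
  },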
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the RPC framework, which `Pipe` requires in PyTorch 1.8, unless it is already running (for more details, see https://pytorch.org/docs/stable/rpc.html):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
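    "try:\n",
    "    torch.distributed.rpc.init_rpc(name='node1', rank=0, world_size=1)\n",
    "except RuntimeError as e:\n",
    "    # Re-running this cell raises 'Address already in use' once RPC is set up;\n",
    "    # ignore that case and re-raise anything else with its original traceback\n",
    "    if 'Address already in use' not in str(e):\n",
    "        raise"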
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the main part for running the model on multiple GPUs:\n",
    "\n",
    "1. We move each block to a GPU and wrap the blocks, in order, into a `Sequential` model, which is the module type `Pipe` expects.\n",
    "2. `chunks` sets the number of micro-batches each mini-batch is split into. For more details, see https://pytorch.org/docs/1.8.0/pipeline.html?highlight=pipeline#"
   ]
  },
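  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what micro-batching buys us, here is a small pure-Python sketch of the forward-pass schedule used by GPipe-style pipelines, which, to our understanding, is the scheme PyTorch's `Pipe` follows: at clock tick k, stage j works on micro-batch k - j, so different stages process different micro-batches at the same time. The function `pipeline_schedule` is hypothetical and assumes every stage takes the same time per micro-batch.\n",
    "\n",
    "```python\n",
    "def pipeline_schedule(num_microbatches, num_stages):\n",
    "    # Returns one list per clock tick; each entry (i, j) means that\n",
    "    # micro-batch i is processed by pipeline stage j during that tick\n",
    "    schedule = []\n",
    "    for k in range(num_microbatches + num_stages - 1):\n",
    "        tick = [(k - j, j) for j in range(num_stages)\n",
    "                if 0 <= k - j < num_microbatches]\n",
    "        schedule.append(tick)\n",
    "    return schedule\n",
    "\n",
    "\n",
    "for k, tick in enumerate(pipeline_schedule(num_microbatches=4, num_stages=3)):\n",
    "    print(f'tick {k}:', tick)\n",
    "```\n",
    "\n",
    "The schedule takes m + n - 1 ticks for m micro-batches and n stages; with only one micro-batch (no chunking), just one stage would be active at any time."
   ]
  },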
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.distributed.pipeline.sync import Pipe\n",
    "\n",
    "\n",
    "# Place each block on a GPU; for modules, .cuda() moves the parameters in place\n",
    "block_1 = block_1.cuda(0)\n",
    "block_2 = block_2.cuda(0)\n",
    "block_3 = block_3.cuda(2)\n",
    "block_4 = block_4.cuda(2)\n",
    "block_5 = block_5.cuda(3)\n",
    "classifier = classifier.cuda(0)\n",
    "\n",
    "# Pipe expects an nn.Sequential whose children are already on their target devices\n",
    "model_parallel = torch.nn.Sequential(\n",
    "    block_1, block_2, block_3, block_4, block_5, classifier)\n",
    "# chunks=8: each mini-batch is split into 8 micro-batches\n",
    "model_parallel = Pipe(model_parallel, chunks=8)\n",
    "optimizer = torch.optim.Adam(model_parallel.parameters(), lr=learning_rate)"
   ]
  },
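  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our understanding is that `Pipe` forms its pipeline stages by grouping consecutive children of the `nn.Sequential` that live on the same device. The hypothetical helper below only mimics that grouping (it does not call into PyTorch) for the placement used above, which would give four stages, with the classifier forming its own stage back on GPU 0. Please check the `Pipe` documentation for the exact partitioning rules.\n",
    "\n",
    "```python\n",
    "from itertools import groupby\n",
    "\n",
    "\n",
    "def group_into_stages(devices):\n",
    "    # devices[i] is the GPU index of the i-th child module of the Sequential;\n",
    "    # consecutive children on the same device are grouped into one stage\n",
    "    stages, start = [], 0\n",
    "    for device, group in groupby(devices):\n",
    "        size = len(list(group))\n",
    "        stages.append((device, list(range(start, start + size))))\n",
    "        start += size\n",
    "    return stages\n",
    "\n",
    "\n",
    "# Device placement from the cell above:\n",
    "# block_1, block_2, block_3, block_4, block_5, classifier\n",
    "print(group_into_stages([0, 0, 2, 2, 3, 0]))\n",
    "```"
   ]
  },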
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/050 | Batch 0000/0352 | Loss: 2.3043\n",
      "Epoch: 001/050 | Batch 0200/0352 | Loss: 2.0182\n",
      "***Epoch: 001/050 | Train. Acc.: 21.711% | Loss: 1.914\n",
      "***Epoch: 001/050 | Valid. Acc.: 21.420% | Loss: 1.906\n",
      "Time elapsed: 1.35 min\n",
      "Epoch: 002/050 | Batch 0000/0352 | Loss: 1.9691\n",
      "Epoch: 002/050 | Batch 0200/0352 | Loss: 1.7432\n",
      "***Epoch: 002/050 | Train. Acc.: 28.180% | Loss: 1.881\n",
      "***Epoch: 002/050 | Valid. Acc.: 28.580% | Loss: 1.895\n",
      "Time elapsed: 2.68 min\n",
      "Epoch: 003/050 | Batch 0000/0352 | Loss: 1.8837\n",
      "Epoch: 003/050 | Batch 0200/0352 | Loss: 1.5357\n",
      "***Epoch: 003/050 | Train. Acc.: 38.922% | Loss: 1.584\n",
      "***Epoch: 003/050 | Valid. Acc.: 38.960% | Loss: 1.587\n",
      "Time elapsed: 4.02 min\n",
      "Epoch: 004/050 | Batch 0000/0352 | Loss: 1.6131\n",
      "Epoch: 004/050 | Batch 0200/0352 | Loss: 1.4481\n",
      "***Epoch: 004/050 | Train. Acc.: 48.900% | Loss: 1.360\n",
      "***Epoch: 004/050 | Valid. Acc.: 49.240% | Loss: 1.367\n",
      "Time elapsed: 5.36 min\n",
      "Epoch: 005/050 | Batch 0000/0352 | Loss: 1.3710\n",
      "Epoch: 005/050 | Batch 0200/0352 | Loss: 1.2681\n",
      "***Epoch: 005/050 | Train. Acc.: 55.542% | Loss: 1.225\n",
      "***Epoch: 005/050 | Valid. Acc.: 55.100% | Loss: 1.246\n",
      "Time elapsed: 6.70 min\n",
      "Epoch: 006/050 | Batch 0000/0352 | Loss: 1.1851\n",
      "Epoch: 006/050 | Batch 0200/0352 | Loss: 1.2468\n",
      "***Epoch: 006/050 | Train. Acc.: 60.782% | Loss: 1.082\n",
      "***Epoch: 006/050 | Valid. Acc.: 58.540% | Loss: 1.123\n",
      "Time elapsed: 8.04 min\n",
      "Epoch: 007/050 | Batch 0000/0352 | Loss: 1.0370\n",
      "Epoch: 007/050 | Batch 0200/0352 | Loss: 1.1411\n",
      "***Epoch: 007/050 | Train. Acc.: 65.060% | Loss: 0.969\n",
      "***Epoch: 007/050 | Valid. Acc.: 63.240% | Loss: 1.023\n",
      "Time elapsed: 9.38 min\n",
      "Epoch: 008/050 | Batch 0000/0352 | Loss: 0.9135\n",
      "Epoch: 008/050 | Batch 0200/0352 | Loss: 1.0559\n",
      "***Epoch: 008/050 | Train. Acc.: 68.411% | Loss: 0.876\n",
      "***Epoch: 008/050 | Valid. Acc.: 66.260% | Loss: 0.946\n",
      "Time elapsed: 10.73 min\n",
      "Epoch: 009/050 | Batch 0000/0352 | Loss: 0.8361\n",
      "Epoch: 009/050 | Batch 0200/0352 | Loss: 1.0417\n",
      "***Epoch: 009/050 | Train. Acc.: 71.067% | Loss: 0.815\n",
      "***Epoch: 009/050 | Valid. Acc.: 67.320% | Loss: 0.908\n",
      "Time elapsed: 12.06 min\n",
      "Epoch: 010/050 | Batch 0000/0352 | Loss: 0.7411\n",
      "Epoch: 010/050 | Batch 0200/0352 | Loss: 0.9193\n",
      "***Epoch: 010/050 | Train. Acc.: 74.264% | Loss: 0.729\n",
      "***Epoch: 010/050 | Valid. Acc.: 70.560% | Loss: 0.847\n",
      "Time elapsed: 13.39 min\n",
      "Epoch: 011/050 | Batch 0000/0352 | Loss: 0.6631\n",
      "Epoch: 011/050 | Batch 0200/0352 | Loss: 0.7807\n",
      "***Epoch: 011/050 | Train. Acc.: 75.558% | Loss: 0.685\n",
      "***Epoch: 011/050 | Valid. Acc.: 71.380% | Loss: 0.840\n",
      "Time elapsed: 14.74 min\n",
      "Epoch: 012/050 | Batch 0000/0352 | Loss: 0.6035\n",
      "Epoch: 012/050 | Batch 0200/0352 | Loss: 0.7203\n",
      "***Epoch: 012/050 | Train. Acc.: 76.527% | Loss: 0.658\n",
      "***Epoch: 012/050 | Valid. Acc.: 71.760% | Loss: 0.856\n",
      "Time elapsed: 16.08 min\n",
      "Epoch: 013/050 | Batch 0000/0352 | Loss: 0.5782\n",
      "Epoch: 013/050 | Batch 0200/0352 | Loss: 0.6597\n",
      "***Epoch: 013/050 | Train. Acc.: 77.604% | Loss: 0.629\n",
      "***Epoch: 013/050 | Valid. Acc.: 71.920% | Loss: 0.859\n",
      "Time elapsed: 17.42 min\n",
      "Epoch: 014/050 | Batch 0000/0352 | Loss: 0.5054\n",
      "Epoch: 014/050 | Batch 0200/0352 | Loss: 0.6732\n",
      "***Epoch: 014/050 | Train. Acc.: 77.451% | Loss: 0.652\n",
      "***Epoch: 014/050 | Valid. Acc.: 71.200% | Loss: 0.883\n",
      "Time elapsed: 18.78 min\n",
      "Epoch: 015/050 | Batch 0000/0352 | Loss: 0.4921\n",
      "Epoch: 015/050 | Batch 0200/0352 | Loss: 0.5878\n",
      "***Epoch: 015/050 | Train. Acc.: 78.100% | Loss: 0.630\n",
      "***Epoch: 015/050 | Valid. Acc.: 71.240% | Loss: 0.907\n",
      "Time elapsed: 20.12 min\n",
      "Epoch: 016/050 | Batch 0000/0352 | Loss: 0.5568\n",
      "Epoch: 016/050 | Batch 0200/0352 | Loss: 0.5079\n",
      "***Epoch: 016/050 | Train. Acc.: 80.038% | Loss: 0.579\n",
      "***Epoch: 016/050 | Valid. Acc.: 71.440% | Loss: 0.927\n",
      "Time elapsed: 21.48 min\n",
      "Epoch: 017/050 | Batch 0000/0352 | Loss: 0.4743\n",
      "Epoch: 017/050 | Batch 0200/0352 | Loss: 0.5260\n",
      "***Epoch: 017/050 | Train. Acc.: 80.993% | Loss: 0.587\n",
      "***Epoch: 017/050 | Valid. Acc.: 71.620% | Loss: 0.972\n",
      "Time elapsed: 22.84 min\n",
      "Epoch: 018/050 | Batch 0000/0352 | Loss: 0.4304\n",
      "Epoch: 018/050 | Batch 0200/0352 | Loss: 0.4014\n",
      "***Epoch: 018/050 | Train. Acc.: 79.338% | Loss: 0.653\n",
      "***Epoch: 018/050 | Valid. Acc.: 69.800% | Loss: 1.060\n",
      "Time elapsed: 24.18 min\n",
      "Epoch: 019/050 | Batch 0000/0352 | Loss: 0.3735\n",
      "Epoch: 019/050 | Batch 0200/0352 | Loss: 0.3300\n",
      "***Epoch: 019/050 | Train. Acc.: 83.069% | Loss: 0.518\n",
      "***Epoch: 019/050 | Valid. Acc.: 73.120% | Loss: 0.991\n",
      "Time elapsed: 25.54 min\n",
      "Epoch: 020/050 | Batch 0000/0352 | Loss: 0.3183\n",
      "Epoch: 020/050 | Batch 0200/0352 | Loss: 0.4193\n",
      "***Epoch: 020/050 | Train. Acc.: 86.129% | Loss: 0.417\n",
      "***Epoch: 020/050 | Valid. Acc.: 74.960% | Loss: 0.961\n",
      "Time elapsed: 26.90 min\n",
      "Epoch: 021/050 | Batch 0000/0352 | Loss: 0.1804\n",
      "Epoch: 021/050 | Batch 0200/0352 | Loss: 0.4772\n",
      "***Epoch: 021/050 | Train. Acc.: 84.840% | Loss: 0.465\n",
      "***Epoch: 021/050 | Valid. Acc.: 73.140% | Loss: 1.030\n",
      "Time elapsed: 28.23 min\n",
      "Epoch: 022/050 | Batch 0000/0352 | Loss: 0.1948\n",
      "Epoch: 022/050 | Batch 0200/0352 | Loss: 0.4274\n",
      "***Epoch: 022/050 | Train. Acc.: 84.476% | Loss: 0.483\n",
      "***Epoch: 022/050 | Valid. Acc.: 73.260% | Loss: 1.068\n",
      "Time elapsed: 29.55 min\n",
      "Epoch: 023/050 | Batch 0000/0352 | Loss: 0.2390\n",
      "Epoch: 023/050 | Batch 0200/0352 | Loss: 0.2604\n",
      "***Epoch: 023/050 | Train. Acc.: 87.033% | Loss: 0.407\n",
      "***Epoch: 023/050 | Valid. Acc.: 74.260% | Loss: 1.054\n",
      "Time elapsed: 30.90 min\n",
      "Epoch: 024/050 | Batch 0000/0352 | Loss: 0.2224\n",
      "Epoch: 024/050 | Batch 0200/0352 | Loss: 0.3437\n",
      "***Epoch: 024/050 | Train. Acc.: 88.540% | Loss: 0.345\n",
      "***Epoch: 024/050 | Valid. Acc.: 74.440% | Loss: 0.996\n",
      "Time elapsed: 32.25 min\n",
      "Epoch: 025/050 | Batch 0000/0352 | Loss: 0.1355\n",
      "Epoch: 025/050 | Batch 0200/0352 | Loss: 0.2179\n",
      "***Epoch: 025/050 | Train. Acc.: 90.258% | Loss: 0.292\n",
      "***Epoch: 025/050 | Valid. Acc.: 75.200% | Loss: 1.034\n",
      "Time elapsed: 33.59 min\n",
      "Epoch: 026/050 | Batch 0000/0352 | Loss: 0.1521\n",
      "Epoch: 026/050 | Batch 0200/0352 | Loss: 0.1664\n",
      "***Epoch: 026/050 | Train. Acc.: 89.593% | Loss: 0.324\n",
      "***Epoch: 026/050 | Valid. Acc.: 74.920% | Loss: 1.139\n",
      "Time elapsed: 34.94 min\n",
      "Epoch: 027/050 | Batch 0000/0352 | Loss: 0.1650\n",
      "Epoch: 027/050 | Batch 0200/0352 | Loss: 0.2342\n",
      "***Epoch: 027/050 | Train. Acc.: 90.831% | Loss: 0.287\n",
      "***Epoch: 027/050 | Valid. Acc.: 75.440% | Loss: 1.242\n",
      "Time elapsed: 36.29 min\n",
      "Epoch: 028/050 | Batch 0000/0352 | Loss: 0.1207\n",
      "Epoch: 028/050 | Batch 0200/0352 | Loss: 0.1094\n",
      "***Epoch: 028/050 | Train. Acc.: 90.716% | Loss: 0.304\n",
      "***Epoch: 028/050 | Valid. Acc.: 74.680% | Loss: 1.334\n",
      "Time elapsed: 37.64 min\n",
      "Epoch: 029/050 | Batch 0000/0352 | Loss: 0.0708\n",
      "Epoch: 029/050 | Batch 0200/0352 | Loss: 0.1305\n",
      "***Epoch: 029/050 | Train. Acc.: 90.467% | Loss: 0.326\n",
      "***Epoch: 029/050 | Valid. Acc.: 73.800% | Loss: 1.357\n",
      "Time elapsed: 38.97 min\n",
      "Epoch: 030/050 | Batch 0000/0352 | Loss: 0.0699\n",
      "Epoch: 030/050 | Batch 0200/0352 | Loss: 0.0874\n",
      "***Epoch: 030/050 | Train. Acc.: 94.009% | Loss: 0.188\n",
      "***Epoch: 030/050 | Valid. Acc.: 75.640% | Loss: 1.134\n",
      "Time elapsed: 40.31 min\n",
      "Epoch: 031/050 | Batch 0000/0352 | Loss: 0.0434\n",
      "Epoch: 031/050 | Batch 0200/0352 | Loss: 0.0705\n",
      "***Epoch: 031/050 | Train. Acc.: 94.442% | Loss: 0.172\n",
      "***Epoch: 031/050 | Valid. Acc.: 76.360% | Loss: 1.172\n",
      "Time elapsed: 41.63 min\n",
      "Epoch: 032/050 | Batch 0000/0352 | Loss: 0.1119\n",
      "Epoch: 032/050 | Batch 0200/0352 | Loss: 0.0542\n",
      "***Epoch: 032/050 | Train. Acc.: 94.224% | Loss: 0.182\n",
      "***Epoch: 032/050 | Valid. Acc.: 76.940% | Loss: 1.169\n",
      "Time elapsed: 42.96 min\n",
      "Epoch: 033/050 | Batch 0000/0352 | Loss: 0.0771\n",
      "Epoch: 033/050 | Batch 0200/0352 | Loss: 0.0898\n",
      "***Epoch: 033/050 | Train. Acc.: 94.258% | Loss: 0.177\n",
      "***Epoch: 033/050 | Valid. Acc.: 76.440% | Loss: 1.189\n",
      "Time elapsed: 44.31 min\n",
      "Epoch: 034/050 | Batch 0000/0352 | Loss: 0.0264\n",
      "Epoch: 034/050 | Batch 0200/0352 | Loss: 0.1268\n",
      "***Epoch: 034/050 | Train. Acc.: 94.207% | Loss: 0.186\n",
      "***Epoch: 034/050 | Valid. Acc.: 76.520% | Loss: 1.160\n",
      "Time elapsed: 45.66 min\n",
      "Epoch: 035/050 | Batch 0000/0352 | Loss: 0.0771\n",
      "Epoch: 035/050 | Batch 0200/0352 | Loss: 0.0835\n",
      "***Epoch: 035/050 | Train. Acc.: 94.469% | Loss: 0.187\n",
      "***Epoch: 035/050 | Valid. Acc.: 76.440% | Loss: 1.249\n",
      "Time elapsed: 47.01 min\n",
      "Epoch: 036/050 | Batch 0000/0352 | Loss: 0.0535\n",
      "Epoch: 036/050 | Batch 0200/0352 | Loss: 0.0453\n",
      "***Epoch: 036/050 | Train. Acc.: 96.549% | Loss: 0.105\n",
      "***Epoch: 036/050 | Valid. Acc.: 77.480% | Loss: 1.154\n",
      "Time elapsed: 48.35 min\n",
      "Epoch: 037/050 | Batch 0000/0352 | Loss: 0.0167\n",
      "Epoch: 037/050 | Batch 0200/0352 | Loss: 0.0867\n"
     ]
    }
   ],
   "source": [
    "_ = train_classifier_simple_v1(num_epochs=num_epochs, model=model_parallel, \n",
    "                               optimizer=optimizer, device=torch.device('cuda:0'), \n",
    "                               train_loader=train_loader, valid_loader=valid_loader, \n",
    "                               logging_interval=200)"
   ]
  },
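  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough, back-of-the-envelope way to see why pipelining does not come for free: in a fill-and-drain (GPipe-style) schedule with m micro-batches and n stages, each stage is busy for only m of the m + n - 1 ticks, so a fraction (n - 1) / (m + n - 1) of the GPU time is idle. The sketch below assumes equally fast stages, which is a simplification since the stages here hold different VGG-16 blocks, and it ignores the cost of copying activations between GPUs. The value `num_stages=4` is an assumption based on the device placement above.\n",
    "\n",
    "```python\n",
    "def bubble_fraction(num_microbatches, num_stages):\n",
    "    # Fraction of stage-ticks left idle in a fill-and-drain schedule\n",
    "    # (assumes every stage takes the same time per micro-batch)\n",
    "    total_ticks = num_microbatches + num_stages - 1\n",
    "    return (num_stages - 1) / total_ticks\n",
    "\n",
    "\n",
    "for m in (1, 4, 8, 32):\n",
    "    print(f'chunks={m:2d}: idle fraction = {bubble_fraction(m, num_stages=4):.2f}')\n",
    "```\n",
    "\n",
    "More chunks shrink the idle fraction, but each micro-batch also gets smaller, which can leave the individual GPUs underutilized."
   ]
  },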
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, training is slower than before (roughly 1.3 to 1.4 min per epoch here). This is not surprising: the main selling point of pipeline parallelism is to train models that do not fit into the memory of a single GPU by spreading them across several GPUs, not to speed up training."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
